{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 0. 导入模块"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "D:\\Program Files\\Anaconda\\lib\\site-packages\\h5py\\__init__.py:36: FutureWarning: Conversion of the second argument of issubdtype from `float` to `np.floating` is deprecated. In future, it will be treated as `np.float64 == np.dtype(float).type`.\n",
      "  from ._conv import register_converters as _register_converters\n"
     ]
    }
   ],
   "source": [
    "from __future__ import absolute_import\n",
    "from __future__ import division\n",
    "from __future__ import print_function\n",
    "import tensorflow as tf\n",
    "import numpy as np\n",
    "import pickle\n",
    "import os\n",
    "import argparse\n",
    "import sys\n",
    "from tensorflow.examples.tutorials.mnist import input_data\n",
    "FLAGS = None"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 1.导入数据"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Extracting ../data/mnist/input_data/train-images-idx3-ubyte.gz\n",
      "Extracting ../data/mnist/input_data/train-labels-idx1-ubyte.gz\n",
      "Extracting ../data/mnist/input_data/t10k-images-idx3-ubyte.gz\n",
      "Extracting ../data/mnist/input_data/t10k-labels-idx1-ubyte.gz\n"
     ]
    }
   ],
   "source": [
    "data_dir = '../data/mnist/input_data/'\n",
    "mnist = input_data.read_data_sets(data_dir, one_hot=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "((55000, 784), (55000, 10))"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "mnist.train.images.shape,mnist.train.labels.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "((10000, 784), (10000, 10))"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "mnist.test.images.shape,mnist.test.labels.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 2. 构造计算图"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "x = tf.placeholder(tf.float32,[None,784])\n",
    "y = tf.placeholder(tf.float32,[None,10])\n",
    "init_lr = tf.placeholder(tf.float32)\n",
    "keep_prob = tf.placeholder(tf.float32)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "x_image = tf.reshape(x,[-1,28,28,1],name='x_image')\n",
    "#定义3个stage,每个stage两层卷积，一层pooling\n",
    "#32个（3，3）卷积核\n",
    "#第1个stage\n",
    "conv1_1 = tf.layers.conv2d(x_image,\n",
    "                           32,\n",
    "                           (3,3),\n",
    "                           padding='same',\n",
    "                           activation=tf.nn.relu,\n",
    "                           name='conv1_1')\n",
    "conv1_2 = tf.layers.conv2d(conv1_1,\n",
    "                           32,\n",
    "                           (3,3),\n",
    "                           padding='same',\n",
    "                           activation=tf.nn.relu,\n",
    "                           name='conv1_2') \n",
    "pooling1 = tf.layers.max_pooling2d(conv1_2,\n",
    "                                   (2,2),\n",
    "                                   (2,2),\n",
    "                                   name='pooling1') \n",
    "#第2个stage\n",
    "conv2_1 = tf.layers.conv2d(pooling1,\n",
    "                           32,\n",
    "                           (3,3),\n",
    "                           padding='same',\n",
    "                           activation=tf.nn.relu,\n",
    "                           name='conv2_1')\n",
    "conv2_2 = tf.layers.conv2d(conv2_1,\n",
    "                           32,\n",
    "                           (3,3),\n",
    "                           padding='same',\n",
    "                           activation=tf.nn.relu,\n",
    "                           name='conv2_2') \n",
    "pooling2 = tf.layers.max_pooling2d(conv2_2,\n",
    "                                   (2,2),\n",
    "                                   (2,2),\n",
    "                                   name='pooling2') \n",
    "#第3个stage\n",
    "conv3_1 = tf.layers.conv2d(pooling2,\n",
    "                           32,\n",
    "                           (3,3),\n",
    "                           padding='same',\n",
    "                           activation=tf.nn.relu,\n",
    "                           name='conv3_1')\n",
    "conv3_2 = tf.layers.conv2d(conv3_1,\n",
    "                           32,\n",
    "                           (3,3),\n",
    "                           padding='same',\n",
    "                           activation=tf.nn.relu,\n",
    "                           name='conv3_2') \n",
    "pooling3 = tf.layers.max_pooling2d(conv3_2,\n",
    "                                   (2,2),\n",
    "                                   (2,2),\n",
    "                                   name='pooling3') \n",
    "#全连接层\n",
    "with tf.name_scope('fc1'):\n",
    "    flatten = tf.layers.flatten(pooling3)\n",
    "    fc1 = tf.layers.dense(flatten,1024,activation=tf.nn.relu)\n",
    "with tf.name_scope('dropout'):\n",
    "    fc1_drop = tf.nn.dropout(fc1,keep_prob)\n",
    "with tf.name_scope('fc2'):\n",
    "    y_ = tf.layers.dense(fc1_drop,10,activation=None)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(TensorShape([Dimension(None), Dimension(10)]),\n",
       " TensorShape([Dimension(None), Dimension(10)]))"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "y_.shape,y.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From <ipython-input-8-64282ee2ecc1>:11: softmax_cross_entropy_with_logits (from tensorflow.python.ops.nn_ops) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "\n",
      "Future major versions of TensorFlow will allow gradients to flow\n",
      "into the labels input on backprop by default.\n",
      "\n",
      "See tf.nn.softmax_cross_entropy_with_logits_v2.\n",
      "\n"
     ]
    }
   ],
   "source": [
    "#定义learning_rate衰减方式为指数衰减\n",
    "epoch_steps = tf.to_int64(tf.div(60000, tf.shape(x)[0]))\n",
    "global_step = tf.train.get_or_create_global_step()\n",
    "current_epoch = global_step//epoch_steps\n",
    "decay_times = current_epoch \n",
    "learning_rate = tf.multiply(init_lr, \n",
    "                            tf.pow(0.65, tf.to_float(decay_times)))\n",
    "#计算loss\n",
    "cross_entropy = tf.reduce_mean(\n",
    "    tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=y_))\n",
    "\n",
    "l2_loss = tf.add_n( [tf.nn.l2_loss(w) for w in tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)] )\n",
    "total_loss = cross_entropy + 7e-5*l2_loss\n",
    "train_op = tf.train.GradientDescentOptimizer(learning_rate).minimize(total_loss,global_step=global_step)\n",
    "#计算accuracy\n",
    "correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))\n",
    "accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 3. 执行graph"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[1;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "\u001b[1;32m<ipython-input-11-9c4252095d8a>\u001b[0m in \u001b[0;36m<module>\u001b[1;34m()\u001b[0m\n\u001b[0;32m      7\u001b[0m         _, loss,l2_loss_value,total_loss_value,lr_value = sess.run(\n\u001b[0;32m      8\u001b[0m                    \u001b[1;33m[\u001b[0m\u001b[0mtrain_op\u001b[0m\u001b[1;33m,\u001b[0m\u001b[0mcross_entropy\u001b[0m\u001b[1;33m,\u001b[0m\u001b[0ml2_loss\u001b[0m\u001b[1;33m,\u001b[0m\u001b[0mtotal_loss\u001b[0m\u001b[1;33m,\u001b[0m\u001b[0mlearning_rate\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m----> 9\u001b[1;33m                    feed_dict={x:batch_xs, y: batch_ys, init_lr:0.01, keep_prob:0.5})\n\u001b[0m\u001b[0;32m     10\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     11\u001b[0m         \u001b[1;32mif\u001b[0m \u001b[1;33m(\u001b[0m\u001b[0mstep\u001b[0m\u001b[1;33m+\u001b[0m\u001b[1;36m1\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;33m%\u001b[0m \u001b[1;36m100\u001b[0m \u001b[1;33m==\u001b[0m \u001b[1;36m0\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32mD:\\Program Files\\Anaconda\\lib\\site-packages\\tensorflow\\python\\client\\session.py\u001b[0m in \u001b[0;36mrun\u001b[1;34m(self, fetches, feed_dict, options, run_metadata)\u001b[0m\n\u001b[0;32m    893\u001b[0m     \u001b[1;32mtry\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    894\u001b[0m       result = self._run(None, fetches, feed_dict, options_ptr,\n\u001b[1;32m--> 895\u001b[1;33m                          run_metadata_ptr)\n\u001b[0m\u001b[0;32m    896\u001b[0m       \u001b[1;32mif\u001b[0m \u001b[0mrun_metadata\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    897\u001b[0m         \u001b[0mproto_data\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mtf_session\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mTF_GetBuffer\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mrun_metadata_ptr\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32mD:\\Program Files\\Anaconda\\lib\\site-packages\\tensorflow\\python\\client\\session.py\u001b[0m in \u001b[0;36m_run\u001b[1;34m(self, handle, fetches, feed_dict, options, run_metadata)\u001b[0m\n\u001b[0;32m   1126\u001b[0m     \u001b[1;32mif\u001b[0m \u001b[0mfinal_fetches\u001b[0m \u001b[1;32mor\u001b[0m \u001b[0mfinal_targets\u001b[0m \u001b[1;32mor\u001b[0m \u001b[1;33m(\u001b[0m\u001b[0mhandle\u001b[0m \u001b[1;32mand\u001b[0m \u001b[0mfeed_dict_tensor\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m   1127\u001b[0m       results = self._do_run(handle, final_targets, final_fetches,\n\u001b[1;32m-> 1128\u001b[1;33m                              feed_dict_tensor, options, run_metadata)\n\u001b[0m\u001b[0;32m   1129\u001b[0m     \u001b[1;32melse\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m   1130\u001b[0m       \u001b[0mresults\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;33m[\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32mD:\\Program Files\\Anaconda\\lib\\site-packages\\tensorflow\\python\\client\\session.py\u001b[0m in \u001b[0;36m_do_run\u001b[1;34m(self, handle, target_list, fetch_list, feed_dict, options, run_metadata)\u001b[0m\n\u001b[0;32m   1342\u001b[0m     \u001b[1;32mif\u001b[0m \u001b[0mhandle\u001b[0m \u001b[1;32mis\u001b[0m \u001b[1;32mNone\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m   1343\u001b[0m       return self._do_call(_run_fn, self._session, feeds, fetches, targets,\n\u001b[1;32m-> 1344\u001b[1;33m                            options, run_metadata)\n\u001b[0m\u001b[0;32m   1345\u001b[0m     \u001b[1;32melse\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m   1346\u001b[0m       \u001b[1;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_do_call\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0m_prun_fn\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_session\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mhandle\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mfeeds\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mfetches\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32mD:\\Program Files\\Anaconda\\lib\\site-packages\\tensorflow\\python\\client\\session.py\u001b[0m in \u001b[0;36m_do_call\u001b[1;34m(self, fn, *args)\u001b[0m\n\u001b[0;32m   1348\u001b[0m   \u001b[1;32mdef\u001b[0m \u001b[0m_do_call\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mfn\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m*\u001b[0m\u001b[0margs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m   1349\u001b[0m     \u001b[1;32mtry\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m-> 1350\u001b[1;33m       \u001b[1;32mreturn\u001b[0m \u001b[0mfn\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0margs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m   1351\u001b[0m     \u001b[1;32mexcept\u001b[0m \u001b[0merrors\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mOpError\u001b[0m \u001b[1;32mas\u001b[0m \u001b[0me\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m   1352\u001b[0m       \u001b[0mmessage\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mcompat\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mas_text\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0me\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mmessage\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32mD:\\Program Files\\Anaconda\\lib\\site-packages\\tensorflow\\python\\client\\session.py\u001b[0m in \u001b[0;36m_run_fn\u001b[1;34m(session, feed_dict, fetch_list, target_list, options, run_metadata)\u001b[0m\n\u001b[0;32m   1327\u001b[0m           return tf_session.TF_Run(session, options,\n\u001b[0;32m   1328\u001b[0m                                    \u001b[0mfeed_dict\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mfetch_list\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mtarget_list\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m-> 1329\u001b[1;33m                                    status, run_metadata)\n\u001b[0m\u001b[0;32m   1330\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m   1331\u001b[0m     \u001b[1;32mdef\u001b[0m \u001b[0m_prun_fn\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0msession\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mhandle\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mfeed_dict\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mfetch_list\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "init_op = tf.global_variables_initializer()\n",
    "# Train&Test\n",
    "with tf.Session() as sess:\n",
    "    sess.run(init_op)\n",
    "    for step in range(3000):\n",
    "        batch_xs, batch_ys = mnist.train.next_batch(100)\n",
    "        _, loss,l2_loss_value,total_loss_value,lr_value = sess.run(\n",
    "                   [train_op,cross_entropy,l2_loss,total_loss,learning_rate], \n",
    "                   feed_dict={x:batch_xs, y: batch_ys, init_lr:0.01, keep_prob:0.5})\n",
    "\n",
    "        if (step+1) % 100 == 0:\n",
    "            print('step%d, entropy loss:%f, l2_loss:%f, total loss:%f, learning_rate:%f' % \n",
    "                 (step+1, loss, l2_loss_value, total_loss_value, lr_value))\n",
    "            train_acc = sess.run(accuracy, feed_dict={x: batch_xs, y: batch_ys, keep_prob:0.5})\n",
    "            print('train_acc:%f' % train_acc)\n",
    "        if (step+1) % 1000 == 0:   \n",
    "            test_acc = sess.run(accuracy, feed_dict={x: mnist.test.images,\n",
    "                                                     y: mnist.test.labels, keep_prob:0.5})\n",
    "            print('test_acc:%f' % test_acc)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
